{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 6 - Federated Learning on MNIST using a CNN\n",
    "\n",
    "## Upgrade to Federated Learning in 10 Lines of PyTorch + PySyft\n",
    "\n",
    "\n",
    "### Context\n",
    "\n",
    "Federated Learning is an exciting and fast-growing Machine Learning technique for building systems that learn on decentralized data. The idea is that the data remains in the hands of its producer (also known as the _worker_), which helps improve privacy and ownership, while the model is shared between workers. One immediate application is predicting the next word on your mobile phone as you type: you don't want the data used for training (i.e. your text messages) to be sent to a central server.\n",
    "\n",
    "The rise of Federated Learning is therefore tightly connected to the spread of data privacy awareness, and the GDPR, which has enforced data protection in the EU since May 2018, has acted as a catalyst. To get ahead of regulation, large actors like Apple and Google have started investing massively in this technology, especially to protect mobile users' privacy, but they have not made their tools available. At OpenMined, we believe that anyone willing to conduct a Machine Learning project should be able to implement privacy-preserving tools with very little effort. We have built tools for encrypting data in a single line [as mentioned in our blog post](https://blog.openmined.org/training-cnns-using-spdz/), and we are now releasing our Federated Learning framework, which leverages the new PyTorch 1.0 version to provide an intuitive interface for building secure and scalable models.\n",
    "\n",
    "In this tutorial, we'll directly use [the canonical example of training a CNN on MNIST using PyTorch](https://github.com/pytorch/examples/blob/master/mnist/main.py) and show how simple it is to implement Federated Learning with it using our [PySyft library](https://github.com/OpenMined/PySyft/). We will go through each part of the example and highlight the code that has changed.\n",
    "\n",
    "You can also find this material in [our blog post](https://blog.openmined.org/upgrade-to-federated-learning-in-10-lines).\n",
    "\n",
    "Authors:\n",
    "- Théo Ryffel - GitHub: [@LaRiffle](https://github.com/LaRiffle)\n",
    "\n",
    "**Ok, let's get started!**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports and model specifications\n",
    "\n",
    "First, we make the standard PyTorch imports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we add the imports specific to PySyft. In particular, we define the remote workers `alice` and `bob`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy  # <-- NEW: import the PySyft library\n",
    "hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionality to support Federated Learning\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")  # <-- NEW: define remote worker bob\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")  # <-- NEW: and alice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define the settings of the learning task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 64\n",
    "        self.test_batch_size = 1000\n",
    "        self.epochs = epochs\n",
    "        self.lr = 0.01\n",
    "        self.momentum = 0.5\n",
    "        self.no_cuda = False\n",
    "        self.seed = 1\n",
    "        self.log_interval = 30\n",
    "        self.save_model = False\n",
    "\n",
    "args = Arguments()\n",
    "\n",
    "use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
    "\n",
    "torch.manual_seed(args.seed)\n",
    "\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data loading and sending to workers\n",
    "We first load the data and transform the training Dataset into a Federated Dataset split across the workers using the `.federate` method. This federated dataset is then passed to a Federated DataLoader. The test dataset remains unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader \n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ]))\n",
    "    .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset\n",
    "    batch_size=args.batch_size, shuffle=True, **kwargs)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=args.test_batch_size, shuffle=True, **kwargs)"
   ]
  },
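  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for what distributing a dataset across workers means, here is a minimal sketch in plain PyTorch. It is not PySyft's implementation of `.federate`: the helper `split_evenly` and the toy dataset are hypothetical, and the sketch assumes that each worker simply receives one contiguous slice of roughly equal size. In the tutorial itself the shards live on the remote workers rather than in a local dictionary.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.utils.data import Subset, TensorDataset\n",
    "\n",
    "\n",
    "def split_evenly(dataset, worker_ids):\n",
    "    '''Hypothetical helper: give each worker one contiguous slice of the dataset.'''\n",
    "    shard_size = -(-len(dataset) // len(worker_ids))  # ceiling division\n",
    "    shards = {}\n",
    "    for i, worker_id in enumerate(worker_ids):\n",
    "        start = i * shard_size\n",
    "        stop = min(start + shard_size, len(dataset))\n",
    "        shards[worker_id] = Subset(dataset, range(start, stop))\n",
    "    return shards\n",
    "\n",
    "\n",
    "# Toy stand-in for MNIST: 10 random 28x28 images with random labels.\n",
    "toy = TensorDataset(torch.randn(10, 1, 28, 28), torch.randint(0, 10, (10,)))\n",
    "shards = split_evenly(toy, ['bob', 'alice'])\n",
    "shard_sizes = {name: len(shard) for name, shard in shards.items()}\n",
    "```\n",
    "\n",
    "Here `shard_sizes` maps each worker name to the number of samples it holds, and the sizes add up to the size of the original dataset."
   ]
  },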
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CNN specification\n",
    "Here we use exactly the same CNN as in the official example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "        self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "        self.fc1 = nn.Linear(4*4*50, 500)\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = x.view(-1, 4*4*50)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
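  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `4*4*50` input size of `fc1` above follows from the layer arithmetic: a 5x5 convolution with stride 1 and no padding shrinks each side by 4, and each 2x2 max-pool halves it, so the side length goes 28 -> 24 -> 12 -> 8 -> 4, with 50 output channels. The short sketch below checks this on a dummy input; the variable names are only illustrative.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Same layer hyperparameters as in the Net class above.\n",
    "conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "\n",
    "x = torch.zeros(1, 1, 28, 28)  # one dummy MNIST-sized image\n",
    "x = F.max_pool2d(F.relu(conv1(x)), 2, 2)  # 28 -> 24 (conv) -> 12 (pool)\n",
    "x = F.max_pool2d(F.relu(conv2(x)), 2, 2)  # 12 -> 8 (conv) -> 4 (pool)\n",
    "\n",
    "feature_shape = tuple(x.shape)  # (batch, channels, height, width)\n",
    "flat_features = x.view(1, -1).shape[1]  # number of inputs expected by fc1\n",
    "```\n",
    "\n",
    "If you change a kernel size, the stride or the input resolution, re-running this check gives the new value to use in `fc1` and in `x.view`."
   ]
  },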
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the train and test functions\n",
    "For the train function, because the data batches are distributed across `alice` and `bob`, you need to send the model to the right location for each batch. Then you perform all the operations remotely, with the same syntax as in local PyTorch. When you're done, you get back the updated model and the loss, which you can use to monitor progress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(args, model, device, federated_train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset\n",
    "        model.send(data.location) # <-- NEW: send the model to the right location\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.get() # <-- NEW: get the model back\n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            loss = loss.get() # <-- NEW: get the loss back\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,\n",
    "                100. * batch_idx / len(federated_train_loader), loss.item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The test function does not change!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability \n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Launch the training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "model = Net().to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment\n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(args, model, device, federated_train_loader, optimizer, epoch)\n",
    "    test(args, model, device, test_loader)\n",
    "\n",
    "if (args.save_model):\n",
    "    torch.save(model.state_dict(), \"mnist_cnn.pt\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Et voilà! Here you are, you have trained a model on remote data using Federated Learning!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One Last Thing\n",
    "I know there's a question you're dying to ask: **how long does it take to do Federated Learning compared to normal PyTorch?**\n",
    "\n",
    "The computation time is actually **less than twice the time** used for normal PyTorch execution! More precisely, it takes about 1.9 times as long, which is a small price compared to the features we were able to add."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "As you can see, we modified 10 lines of code to upgrade the official PyTorch example on MNIST to a real Federated Learning setting!\n",
    "\n",
    "Of course, there are dozens of improvements we could think of. We would like the computation to run in parallel on the workers and to perform federated averaging, to update the central model only every `n` batches, to reduce the number of messages exchanged between workers, etc. These are features we're working on to make Federated Learning ready for a production environment, and we'll write about them as soon as they are released!\n",
    "\n",
    "You should now be able to do Federated Learning by yourself! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Try our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to give you a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community!\n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
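  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Appendix: a sketch of federated averaging\n",
    "\n",
    "The conclusion above mentions federated averaging as a possible improvement. It can be sketched in plain PyTorch as follows. This is an illustration under simplifying assumptions, not a PySyft feature shown in this tutorial: the helper `average_state_dicts`, the toy `nn.Linear` model and the random data are all hypothetical, the workers are simulated locally, and every worker is weighted equally (the usual formulation weights each worker by its number of samples).\n",
    "\n",
    "```python\n",
    "import copy\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def average_state_dicts(state_dicts):\n",
    "    '''Element-wise mean of parameter dictionaries that share the same keys.'''\n",
    "    averaged = copy.deepcopy(state_dicts[0])\n",
    "    for key in averaged:\n",
    "        stacked = torch.stack([sd[key].float() for sd in state_dicts])\n",
    "        averaged[key] = stacked.mean(dim=0)\n",
    "    return averaged\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "global_model = nn.Linear(4, 2)  # toy stand-in for the CNN\n",
    "\n",
    "# Each simulated worker takes one SGD step on its own copy and its own data.\n",
    "local_states = []\n",
    "for _ in range(2):\n",
    "    local_model = copy.deepcopy(global_model)\n",
    "    optimizer = torch.optim.SGD(local_model.parameters(), lr=0.1)\n",
    "    data, target = torch.randn(8, 4), torch.randn(8, 2)\n",
    "    loss = nn.functional.mse_loss(local_model(data), target)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    local_states.append(copy.deepcopy(local_model.state_dict()))\n",
    "\n",
    "# The central model is replaced by the average of the local models.\n",
    "global_model.load_state_dict(average_state_dicts(local_states))\n",
    "```\n",
    "\n",
    "Compared with the training loop in this tutorial, where a single model visits one worker per batch, this pattern lets the workers train independently and only exchanges parameters at the averaging step."
   ]
  },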
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
